// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#define REDUCE_OP PoolType::SUM
#define REDUCE_DIM ReduceDim::REDUCE_ROW

#define BCAST_LLKOP EltwiseBinaryType::ELWMUL
#define BCAST_DIM BroadcastType::COL

#include "compute_kernel_api.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/bcast.h"
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/layernorm.h"
#include "compute_kernel_api/eltwise_binary_sfpu.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "ttnn/operations/normalization/kernel_util/compute/numeric.h"

namespace NAMESPACE {

/**
 * Streamed LayerNorm / RMSNorm compute kernel.
 *
 * For each of NCHt tile-rows the input is consumed blk tiles at a time, in up to three passes:
 *   1. (LayerNorm only) E[x]                          -> cb_ex     (one tile)
 *   2. Var = sum((x-E[x])^2) / W  (RMSNORM: E[x^2])   -> cb_ex2    (one tile)
 *   3. out = ((x-E[x]) * 1/sqrt(Var+eps)) [* gamma] [+ beta] -> cb_out
 *
 * Runtime args:
 *   0: NCHt - number of tile-rows processed by this core.
 * Compile-time args:
 *   0: Wt                - row width in tiles.
 *   1: blk               - tiles handled per inner step. Both inner loops assume Wt % blk == 0;
 *                          otherwise the final-iteration check below never fires.
 *                          NOTE(review): presumably guaranteed by the program factory.
 *   2: do_gamma          - 1 to multiply by gamma (streamed blk tiles at a time via cb_gamma).
 *   3: do_beta           - 1 to add beta (streamed blk tiles at a time via cb_beta).
 *   4: FLOAT32_DTYPE     - read but not used in this kernel body.
 *   5: FLOAT32_REDUCTION - forwarded to the reduce / mean helpers.
 *   6: LEGACY_RSQRT      - selects the rsqrt implementation.
 *   7: one_over_W        - scalar passed to mul_unary_tile and the mean helper
 *                          (presumably the float bit pattern of 1/W).
 * Defines:
 *   RMSNORM      - skip the mean pass; x is used as-is instead of x-E[x].
 *   FUSE_PRE_ADD - x = a + b, with a streamed through cb_in and b through cb_inb.
 */
void MAIN {
    namespace kutil = norm::kernel_util;
    namespace numeric = kutil::compute::numeric;
    namespace policies = kutil::compute::policies;

    uint32_t NCHt = get_arg_val<uint32_t>(0);
    constexpr uint32_t Wt = get_compile_time_arg_val(0);
    constexpr uint32_t blk = get_compile_time_arg_val(1);
    constexpr uint32_t do_gamma = get_compile_time_arg_val(2);
    constexpr uint32_t do_beta = get_compile_time_arg_val(3);
    constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(4) == 1;
    constexpr bool FLOAT32_REDUCTION = get_compile_time_arg_val(5) == 1;
    constexpr bool LEGACY_RSQRT = get_compile_time_arg_val(6) == 1;
    constexpr uint32_t one_over_W = get_compile_time_arg_val(7);

    constexpr uint32_t onetile = 1;
    constexpr uint32_t dst0 = 0;  // DST slot used for single-tile results

    // NOTE(review): legacy note says the entire W dimension must fit in the intermed0 CB; this kernel
    // streams W in blk-sized chunks, so confirm whether the note still applies.
    constexpr auto cb_scaler = tt::CBIndex::c_2;  // single tile generated by the reader
    constexpr auto cb_eps = tt::CBIndex::c_3;     // single tile generated by the reader
    constexpr auto cb_in = tt::CBIndex::c_0;      // input x or a for fused pre-add (x=a+b)
    constexpr auto cb_inb = tt::CBIndex::c_1;     // input b for fused pre-add
    constexpr auto cb_out = tt::CBIndex::c_16;    // output
    constexpr auto cb_gamma = tt::CBIndex::c_5;
    constexpr auto cb_beta = tt::CBIndex::c_6;
    // x minus mean. Must be constexpr: it initializes the constexpr cb_x under FUSE_PRE_ADD+RMSNORM.
    constexpr uint32_t cb_xmm = tt::CBIndex::c_24;
    constexpr auto cb_ex = tt::CBIndex::c_18;     // E[x]
    constexpr auto cb_ex2 = tt::CBIndex::c_19;    // E[(x-E[x])^2]  (RMSNORM: E[x^2])
    constexpr auto cb_xmm2 = tt::CBIndex::c_20;   // xmm^2
    constexpr auto cb_ex2pe = tt::CBIndex::c_21;  // 1/sqrt(E[(x-E[x])^2]+eps)
    // Destination of the normalized tiles: the gamma/beta staging CB when an affine step follows,
    // otherwise the output CB directly.
    constexpr uint32_t cb_fusion = (do_gamma == 1 or do_beta == 1) ? static_cast<uint32_t>(tt::CBIndex::c_22)
                                                                   : static_cast<uint32_t>(cb_out);
    constexpr auto scaler0 = 0;
    constexpr auto cb_accumulate = tt::CBIndex::c_26;  // running partial sum of (x-E[x])^2 across blk steps
#ifdef FUSE_PRE_ADD
#ifdef RMSNORM
    constexpr uint32_t cb_x = cb_xmm;
#else
    constexpr uint32_t cb_x = tt::CBIndex::c_23;
#endif
#else
    constexpr uint32_t cb_x = cb_in;
#endif

#ifdef FUSE_PRE_ADD
    binary_op_init_common(cb_in, cb_inb, cb_x);
#else
    binary_op_init_common(cb_in, cb_scaler, cb_ex);
#endif
    cb_wait_front(cb_scaler, 1);  // comes from the reader
    cb_wait_front(cb_eps, 1);     // comes from the reader

    for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
#ifndef RMSNORM
        // Start of
        //  E[x]
        //  aka   ∑(x)
        //      --------
        //         n
#ifdef FUSE_PRE_ADD
        numeric::row_wise_mean_with_pre_add<FLOAT32_REDUCTION, policies::PopInputPolicy::POP>(
            cb_in, cb_inb, cb_scaler, cb_ex, one_over_W, Wt, blk);
#else
        numeric::row_wise_mean<FLOAT32_REDUCTION, policies::PopInputPolicy::POP>(
            cb_in, cb_scaler, cb_ex, one_over_W, Wt, blk);
#endif
        // cb_ex is read as a broadcast operand by both passes below and popped only at the end of the
        // row; wait for it here so the first pass never unpacks it before it has been pushed.
        cb_wait_front(cb_ex, onetile);
#endif  // !RMS ifdef end
        // Start of
        // Var Calculation
        // Var(X) = ∑(x-E[x])^2        (RMSNORM: ∑x^2)
        //         -----------
        //              n
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            tile_regs_acquire();
            cb_wait_front(cb_in, blk);
#ifdef RMSNORM
            // RMSNorm: no mean subtraction, bring x straight into DST.
            reconfig_data_format_srca(cb_in);
            copy_tile_init(cb_in);
            for (uint32_t j = 0; j < blk; j++) {
                copy_tile(cb_in, j, j);
            }
#else
            // x-E[x]
            reconfig_data_format(cb_in, cb_ex);
            sub_bcast_cols_init_short(cb_in, cb_ex);
            for (uint32_t j = 0; j < blk; j++) {
                sub_tiles_bcast_cols(cb_in, cb_ex, j, 0, j);
            }
#endif
            cb_pop_front(cb_in, blk);
#ifdef FUSE_PRE_ADD
            // Add b in place (DST contents are reused as SRCB): (a - E[x]) + b == x - E[x].
            cb_wait_front(cb_inb, blk);
            reconfig_data_format_srca(cb_in, cb_inb);
            binary_dest_reuse_tiles_init<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb);
            for (uint32_t j = 0; j < blk; j++) {
                binary_dest_reuse_tiles<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb, j, j);
            }
            cb_pop_front(cb_inb, blk);
#endif
            // (x-E[x])^2. Pack to CB
            square_tile_init();
            for (uint32_t j = 0; j < blk; j++) {
                square_tile(j);
            }
            tile_regs_commit();
            tile_regs_wait();
            cb_reserve_back(cb_xmm2, blk);
            pack_reconfig_data_format(cb_xmm2);
            for (uint32_t j = 0; j < blk; j++) {
                pack_tile(j, cb_xmm2);
            }
            tile_regs_release();
            cb_push_back(cb_xmm2, blk);

            // Row-reduce the blk squared tiles into dst0, seeded with the partial sum of earlier steps.
            tile_regs_acquire();
            if (wt > 0) {
                cb_wait_front(cb_accumulate, onetile);
                reconfig_data_format_srca(cb_accumulate);
                copy_tile_init(cb_accumulate);
                copy_tile(cb_accumulate, 0, dst0);
                cb_pop_front(cb_accumulate, onetile);
            }
            cb_wait_front(cb_xmm2, blk);

            // Accumulate (x-E[x])^2
            reconfig_data_format(cb_xmm2, cb_scaler);
            reduce_init<REDUCE_OP, REDUCE_DIM, FLOAT32_REDUCTION>(cb_xmm2, cb_scaler, cb_accumulate);
            for (uint32_t j = 0; j < blk; j++) {
                reduce_tile<REDUCE_OP, REDUCE_DIM, FLOAT32_REDUCTION>(cb_xmm2, cb_scaler, j, scaler0, dst0);
            }

            cb_pop_front(cb_xmm2, blk);

            // Only reached when Wt is a multiple of blk (see kernel header).
            const bool final_iter = wt == Wt - blk;
            const auto pack_cb = final_iter ? cb_ex2 : cb_accumulate;
            if (final_iter) {
                // Divide by W
                binop_with_scalar_tile_init();
                mul_unary_tile(dst0, one_over_W);
            }

            reduce_uninit<FLOAT32_REDUCTION>();
            tile_regs_commit();
            tile_regs_wait();

            cb_reserve_back(pack_cb, onetile);
            pack_reconfig_data_format(pack_cb);
            pack_tile(dst0, pack_cb);
            tile_regs_release();
            cb_push_back(pack_cb, onetile);
        }

        // End of
        // Var Calculation
        // Var(X) = ∑(x-E[x])^2
        //         -----------

        // Start of
        // Calculation
        //                     1
        //  cb_ex2pe =   -------------
        //               √(Var(X) + ε)
        cb_wait_front(cb_ex2, onetile);
        reconfig_data_format(cb_ex2, cb_eps);
        tile_regs_acquire();

        add_tiles_init(cb_ex2, cb_eps);
        add_tiles(cb_ex2, cb_eps, 0, 0, dst0);

        rsqrt_tile_init<LEGACY_RSQRT>();
        rsqrt_tile<LEGACY_RSQRT>(dst0);

        tile_regs_commit();
        tile_regs_wait();
        // Reserve before packing: the pack side must wait for the previous row's pop of cb_ex2pe.
        cb_reserve_back(cb_ex2pe, onetile);
        pack_reconfig_data_format(cb_ex2pe);
        pack_tile(dst0, cb_ex2pe);
        tile_regs_release();
        cb_push_back(cb_ex2pe, onetile);
        cb_pop_front(cb_ex2, onetile);

        // broadcasts the tile since cb_ex2pe is a column vector that contains the important data
        cb_wait_front(cb_ex2pe, onetile);
        tile_regs_acquire();
        reconfig_data_format_srca(cb_ex2pe);

        unary_bcast_init<BroadcastType::COL>(cb_ex2pe, cb_ex2pe);
        unary_bcast<BroadcastType::COL>(cb_ex2pe, 0, dst0);
        cb_pop_front(cb_ex2pe, onetile);

        tile_regs_commit();
        tile_regs_wait();
        // Reserve before repacking the broadcast tile into the CB just popped above.
        cb_reserve_back(cb_ex2pe, onetile);
        pack_tile(dst0, cb_ex2pe);
        tile_regs_release();
        cb_push_back(cb_ex2pe, onetile);
        cb_wait_front(cb_ex2pe, onetile);

        // End of
        // Calculation
        //                     1
        //  cb_ex2pe =   -------------
        //               √(Var(X) + ε)

        // Start of
        // Final Val Calc
        //    x-E[X]
        //(---------------*𝛄)+ß
        //  √(Var(X)+ε)
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            tile_regs_acquire();
            cb_wait_front(cb_in, blk);
#ifdef RMSNORM
            reconfig_data_format_srca(cb_in);
            copy_tile_init(cb_in);
            for (uint32_t j = 0; j < blk; j++) {
                copy_tile(cb_in, j, j);
            }
#else
            reconfig_data_format(cb_in, cb_ex);
            sub_bcast_cols_init_short(cb_in, cb_ex);
            // x-E[x]
            for (uint32_t j = 0; j < blk; j++) {
                sub_tiles_bcast_cols(cb_in, cb_ex, j, 0, j);
            }
#endif
            cb_pop_front(cb_in, blk);
#ifdef FUSE_PRE_ADD
            cb_wait_front(cb_inb, blk);
            reconfig_data_format_srca(cb_inb);
            binary_dest_reuse_tiles_init<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb);
            for (uint32_t j = 0; j < blk; j++) {
                binary_dest_reuse_tiles<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb, j, j);
            }
            cb_pop_front(cb_inb, blk);
#endif
            tile_regs_commit();
            tile_regs_wait();
            // Note: We shouldn't have to pack to
            // intermediate CB. We should be able to
            // do a binary dest with reuse (as we used
            // to). However, tt-llk #868 is preventing
            // that from working at the moment.
            cb_reserve_back(cb_xmm, blk);
            pack_reconfig_data_format(cb_xmm);
            for (uint32_t j = 0; j < blk; j++) {
                pack_tile(j, cb_xmm);
            }
            cb_push_back(cb_xmm, blk);
            tile_regs_release();

            // (x-E[x]) * 1/sqrt(Var+eps)
            cb_wait_front(cb_xmm, blk);
            reconfig_data_format(cb_xmm, cb_ex2pe);
            tile_regs_acquire();

            mul_tiles_init(cb_xmm, cb_ex2pe);
            for (uint32_t j = 0; j < blk; j++) {
                mul_tiles(cb_xmm, cb_ex2pe, j, 0, j);
            }
            tile_regs_commit();
            tile_regs_wait();

            // cb_fusion is cb_out when neither gamma nor beta is applied.
            cb_reserve_back(cb_fusion, blk);
            pack_reconfig_data_format(cb_fusion);
            for (uint32_t j = 0; j < blk; j++) {
                pack_tile(j, cb_fusion);
            }
            tile_regs_release();
            cb_push_back(cb_fusion, blk);
            cb_pop_front(cb_xmm, blk);

            if constexpr (do_gamma == 1) {
                // * gamma (row broadcast). Result goes to cb_out, or back to cb_fusion when beta follows.
                tile_regs_acquire();
                tile_regs_wait();
                reconfig_data_format(cb_fusion, cb_gamma);
                if constexpr (!do_beta) {
                    pack_reconfig_data_format(cb_out);
                }
                cb_wait_front(cb_gamma, blk);
                cb_wait_front(cb_fusion, blk);
                mul_bcast_rows_init_short(cb_fusion, cb_gamma);
                for (uint32_t j = 0; j < blk; j++) {
                    mul_tiles_bcast_rows(cb_fusion, cb_gamma, j, j, j);
                }
                tile_regs_commit();
                cb_pop_front(cb_gamma, blk);
                cb_pop_front(cb_fusion, blk);
                if constexpr (!do_beta) {
                    cb_reserve_back(cb_out, blk);
                    for (uint32_t j = 0; j < blk; j++) {
                        pack_tile(j, cb_out);
                    }
                    cb_push_back(cb_out, blk);
                } else {
                    cb_reserve_back(cb_fusion, blk);
                    for (uint32_t j = 0; j < blk; j++) {
                        pack_tile(j, cb_fusion);
                    }
                    cb_push_back(cb_fusion, blk);
                }

                tile_regs_release();
            }
            if constexpr (do_beta == 1) {
                // + beta (row broadcast) -> cb_out
                tile_regs_acquire();
                tile_regs_wait();
                reconfig_data_format(cb_fusion, cb_beta);
                pack_reconfig_data_format(cb_out);
                cb_wait_front(cb_beta, blk);
                cb_wait_front(cb_fusion, blk);
                add_bcast_rows_init_short(cb_fusion, cb_beta);
                for (uint32_t j = 0; j < blk; j++) {
                    add_tiles_bcast_rows(cb_fusion, cb_beta, j, j, j);
                }
                tile_regs_commit();
                cb_pop_front(cb_beta, blk);
                cb_pop_front(cb_fusion, blk);
                cb_reserve_back(cb_out, blk);
                for (uint32_t j = 0; j < blk; j++) {
                    pack_tile(j, cb_out);
                }
                tile_regs_release();
                cb_push_back(cb_out, blk);
            }
        }
        // End of
        // Final Val Calc
        //    x-E[X]
        //(---------------*𝛄)+ß
        //  √(Var(X)+ε)
#ifndef RMSNORM
        // cb_ex is only produced in the LayerNorm path; popping it under RMSNORM would ack a tile
        // that was never pushed.
        cb_pop_front(cb_ex, onetile);
#endif
        cb_pop_front(cb_ex2pe, onetile);
    }  // NCHt loop
}
}  // namespace NAMESPACE
